{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 训练像素模型\n",
    "用这个文件可以训练出需要使用的光谱像素点模型"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "ename": "NotImplementedError",
     "evalue": "开发时和机器上部署的路径不同，请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mNotImplementedError\u001B[0m                       Traceback (most recent call last)",
      "Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 3>\u001B[1;34m()\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m      2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m      4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m      5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n",
      "File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m     17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m     18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m     23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m     24\u001B[0m     \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m     25\u001B[0m \u001B[38;5;124;03m    自然排序\u001B[39;00m\n\u001B[0;32m     26\u001B[0m \u001B[38;5;124;03m    :param l: 待排序\u001B[39;00m\n\u001B[0;32m     27\u001B[0m \u001B[38;5;124;03m    :return:\u001B[39;00m\n\u001B[0;32m     28\u001B[0m \u001B[38;5;124;03m    \"\"\"\u001B[39;00m\n",
      "File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m      3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m      7\u001B[0m     \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m      8\u001B[0m     nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m      9\u001B[0m     nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n",
      "File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m     28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m     30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同，请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m     32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\"  # 开发时的路径\u001B[39;00m\n\u001B[0;32m     33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\"  # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m     34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\"  # 开发时的路径\u001B[39;00m\n\u001B[0;32m     35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\"  # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m     36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n",
      "\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同，请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "from utils import read_envi_ascii\n",
    "from config import Config\n",
    "from models import ManualTree"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 一些变量"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "data_path = r'data/envi20220802.txt'\n",
    "name_dict = {'tobacco': 1, 'yantou':2, 'kazhi':3, 'bomo':4, 'jiaodai':5, 'background':0}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 构建数据集"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "data = read_envi_ascii(data_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zibian (569, 448)\n",
      "tobacco (1457, 448)\n",
      "yantou (354, 448)\n",
      "kazhi (449, 448)\n",
      "bomo (1154, 448)\n",
      "jiaodai (566, 448)\n",
      "background (1235, 448)\n"
     ]
    }
   ],
   "source": [
    "_ = [print(class_name, d.shape) for class_name, d in data.items()]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
    "data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]\n",
    "data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 取出需要的22个特征谱段"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这些是现在的数据:  (5215, 448) (5215,)\n",
      "截取其中需要的部分后:  (5215, 22) (5215,)\n"
     ]
    }
   ],
   "source": [
    "print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n",
    "data_x_cut = data_x[..., Config.bands]\n",
    "print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 进行样本平衡"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这是重采样后的数据:  (8742, 22) (8742,)\n"
     ]
    }
   ],
   "source": [
    "from imblearn.over_sampling import RandomOverSampler\n",
    "ros = RandomOverSampler(random_state=0)\n",
    "x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n",
    "print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 进行模型训练\n",
    "分出一部分数据进行训练"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "from models import DecisionTree\n",
    "from sklearn.model_selection import train_test_split"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "tree = DecisionTree(class_weight={1:20})"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "tree = tree.fit(train_x, train_y)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 模型评估"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 多分类精度"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "pred_y = tree.predict(test_x)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       1.00      1.00      1.00       312\n",
      "         1.0       0.98      0.97      0.98       289\n",
      "         2.0       0.97      1.00      0.99       312\n",
      "         3.0       0.95      0.99      0.97       278\n",
      "         4.0       0.98      0.92      0.95       288\n",
      "         5.0       0.98      0.98      0.98       270\n",
      "\n",
      "    accuracy                           0.98      1749\n",
      "   macro avg       0.98      0.98      0.98      1749\n",
      "weighted avg       0.98      0.98      0.98      1749\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "print(classification_report(y_pred=pred_y, y_true=test_y))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 二分类精度"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "test_y[test_y <= 1] = 0\n",
    "test_y[test_y > 1] = 1\n",
    "pred_y = tree.predict_bin(test_x)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       1.00      1.00      1.00       598\n",
      "         1.0       1.00      1.00      1.00      1151\n",
      "\n",
      "    accuracy                           1.00      1749\n",
      "   macro avg       1.00      1.00      1.00      1749\n",
      "weighted avg       1.00      1.00      1.00      1749\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_true=pred_y, y_pred=pred_y))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 模型保存"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
    "with open(path, 'wb') as f:\n",
    "    pickle.dump(tree, f)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "from models import DecisionTree\n",
    "from sklearn.model_selection import train_test_split"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "tree = DecisionTree(class_weight={1:20})"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "tree = tree.fit(train_x, train_y)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 模型评估"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 多分类精度"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "pred_y = tree.predict(test_x)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       1.00      1.00      1.00       304\n",
      "         1.0       0.97      0.98      0.97       275\n",
      "         2.0       0.98      1.00      0.99       297\n",
      "         3.0       0.95      0.99      0.97       293\n",
      "         4.0       0.98      0.91      0.95       316\n",
      "         5.0       0.98      0.98      0.98       264\n",
      "\n",
      "    accuracy                           0.98      1749\n",
      "   macro avg       0.98      0.98      0.98      1749\n",
      "weighted avg       0.98      0.98      0.98      1749\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "print(classification_report(y_pred=pred_y, y_true=test_y))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 二分类精度"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "test_y[test_y <= 1] = 0\n",
    "test_y[test_y > 1] = 1\n",
    "pred_y = tree.predict_bin(test_x)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       1.00      1.00      1.00       582\n",
      "         1.0       1.00      1.00      1.00      1167\n",
      "\n",
      "    accuracy                           1.00      1749\n",
      "   macro avg       1.00      1.00      1.00      1749\n",
      "weighted avg       1.00      1.00      1.00      1749\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_true=pred_y, y_pred=pred_y))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 模型保存"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
    "with open(path, 'wb') as f:\n",
    "    pickle.dump(tree, f)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}